"""The tests for the Restore component."""

from collections.abc import Coroutine
from datetime import datetime, timedelta
import logging
from typing import Any
from unittest.mock import Mock, patch

from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import CoreState, HomeAssistant, State
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.reload import async_get_platform_without_config_entry
from homeassistant.helpers.restore_state import (
    DATA_RESTORE_STATE,
    STORAGE_KEY,
    RestoreEntity,
    RestoreStateData,
    StoredState,
    async_get,
    async_load,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.util import dt as dt_util

from tests.common import (
    MockEntityPlatform,
    MockModule,
    MockPlatform,
    async_fire_time_changed,
    json_round_trip,
    mock_integration,
    mock_platform,
)

_LOGGER = logging.getLogger(__name__)
DOMAIN = "test_domain"
PLATFORM = "test_platform"


async def test_caching_data(hass: HomeAssistant) -> None:
    """Test that we cache data."""
    now = dt_util.utcnow()
    stored_states = [
        StoredState(State("input_boolean.b0", "on"), None, now),
        StoredState(State("input_boolean.b1", "on"), None, now),
        StoredState(State("input_boolean.b2", "on"), None, now),
    ]

    data = async_get(hass)
    await hass.async_block_till_done()
    await data.store.async_save([state.as_dict() for state in stored_states])

    # Emulate a fresh load
    hass.data.pop(DATA_RESTORE_STATE)

    with (
        patch(
            "homeassistant.helpers.restore_state.Store.async_load",
            side_effect=HomeAssistantError,
        ),
        patch("homeassistant.helpers.restore_state.Store.async_save"),
    ):
        # Failure to load should not be treated as fatal
        await async_load(hass)

    data = async_get(hass)
    assert data.last_states == {}

    # Mock that only b1 is present this run
    with patch(
        "homeassistant.helpers.restore_state.Store.async_save"
    ) as mock_write_data:
        await async_load(hass)
        await hass.async_block_till_done()

    data = async_get(hass)

    entity = RestoreEntity()
    entity.hass = hass
    entity.entity_id = "input_boolean.b1"

    # Mock that only b1 is present this run
    state = await entity.async_get_last_state()

    assert state is not None
    assert state.entity_id == "input_boolean.b1"
    assert state.state == "on"

    assert mock_write_data.called


async def test_periodic_write(hass: HomeAssistant) -> None:
    """Test that we write periodiclly but not after stop."""
    data = async_get(hass)
    await hass.async_block_till_done()
    await data.store.async_save([])

    # Emulate a fresh load
    with patch(
        "homeassistant.helpers.restore_state.Store.async_save"
    ) as mock_write_data:
        hass.data.pop(DATA_RESTORE_STATE)
        await async_load(hass)
        data = async_get(hass)

        entity = RestoreEntity()
        entity.hass = hass
        entity.entity_id = "input_boolean.b1"

        await entity.async_get_last_state()
        await hass.async_block_till_done()

    assert mock_write_data.called

    with patch(
        "homeassistant.helpers.restore_state.Store.async_save"
    ) as mock_write_data:
        async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=15))
        await hass.async_block_till_done()

    assert mock_write_data.called

    with patch(
        "homeassistant.helpers.restore_state.Store.async_save"
    ) as mock_write_data:
        hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
        await hass.async_block_till_done()

    assert mock_write_data.called

    with patch(
        "homeassistant.helpers.restore_state.Store.async_save"
    ) as mock_write_data:
        async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=30))
        await hass.async_block_till_done()

    assert not mock_write_data.called


async def test_save_persistent_states(hass: HomeAssistant) -> None:
    """Test that we cancel the currently running job, save the data, and verify the perdiodic job continues."""
    data = async_get(hass)
    await hass.async_block_till_done()
    await data.store.async_save([])

    # Emulate a fresh load
    with patch(
        "homeassistant.helpers.restore_state.Store.async_save"
    ) as mock_write_data:
        hass.data.pop(DATA_RESTORE_STATE)
        await async_load(hass)
        data = async_get(hass)

        entity = RestoreEntity()
        entity.hass = hass
        entity.entity_id = "input_boolean.b1"

        await entity.async_get_last_state()
        await hass.async_block_till_done()

    # Startup Save
    assert mock_write_data.called

    with patch(
        "homeassistant.helpers.restore_state.Store.async_save"
    ) as mock_write_data:
        async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=10))
        await hass.async_block_till_done()

    # Not quite the first interval
    assert not mock_write_data.called

    with patch(
        "homeassistant.helpers.restore_state.Store.async_save"
    ) as mock_write_data:
        await RestoreStateData.async_save_persistent_states(hass)
        await hass.async_block_till_done()

    assert mock_write_data.called

    with patch(
        "homeassistant.helpers.restore_state.Store.async_save"
    ) as mock_write_data:
        async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=20))
        await hass.async_block_till_done()
    # Verify still saving
    assert mock_write_data.called

    with patch(
        "homeassistant.helpers.restore_state.Store.async_save"
    ) as mock_write_data:
        hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
        await hass.async_block_till_done()
    # Verify normal shutdown
    assert mock_write_data.called


async def test_hass_starting(hass: HomeAssistant) -> None:
    """Test that we cache data."""
    hass.set_state(CoreState.starting)

    now = dt_util.utcnow()
    stored_states = [
        StoredState(State("input_boolean.b0", "on"), None, now),
        StoredState(State("input_boolean.b1", "on"), None, now),
        StoredState(State("input_boolean.b2", "on"), None, now),
    ]

    data = async_get(hass)
    await hass.async_block_till_done()
    await data.store.async_save([state.as_dict() for state in stored_states])

    # Emulate a fresh load
    hass.set_state(CoreState.not_running)
    hass.data.pop(DATA_RESTORE_STATE)
    await async_load(hass)
    data = async_get(hass)

    entity = RestoreEntity()
    entity.hass = hass
    entity.entity_id = "input_boolean.b1"

    all_states = hass.states.async_all()
    assert len(all_states) == 0
    hass.states.async_set("input_boolean.b1", "on")

    # Mock that only b1 is present this run
    with patch(
        "homeassistant.helpers.restore_state.Store.async_save"
    ) as mock_write_data:
        state = await entity.async_get_last_state()
        await hass.async_block_till_done()

    assert state is not None
    assert state.entity_id == "input_boolean.b1"
    assert state.state == "on"
    hass.states.async_remove("input_boolean.b1")

    # Assert that no data was written yet, since hass is still starting.
    assert not mock_write_data.called

    # Finish hass startup
    with patch(
        "homeassistant.helpers.restore_state.Store.async_save"
    ) as mock_write_data:
        hass.bus.async_fire(EVENT_HOMEASSISTANT_START)
        await hass.async_block_till_done()

    # Assert that this session states were written
    assert mock_write_data.called


async def test_dump_data(hass: HomeAssistant) -> None:
    """Test that we cache data."""
    states = [
        State("input_boolean.b0", "on"),
        State("input_boolean.b1", "on"),
        State("input_boolean.b2", "on"),
        State("input_boolean.b5", "unavailable", {"restored": True}),
    ]

    platform = MockEntityPlatform(hass, domain="input_boolean")
    entity = Entity()
    entity.hass = hass
    entity.entity_id = "input_boolean.b0"
    await platform.async_add_entities([entity])

    entity = RestoreEntity()
    entity.hass = hass
    entity.entity_id = "input_boolean.b1"
    await platform.async_add_entities([entity])

    data = async_get(hass)
    now = dt_util.utcnow()
    data.last_states = {
        "input_boolean.b0": StoredState(State("input_boolean.b0", "off"), None, now),
        "input_boolean.b1": StoredState(State("input_boolean.b1", "off"), None, now),
        "input_boolean.b2": StoredState(State("input_boolean.b2", "off"), None, now),
        "input_boolean.b3": StoredState(State("input_boolean.b3", "off"), None, now),
        "input_boolean.b4": StoredState(
            State("input_boolean.b4", "off"),
            None,
            datetime(1985, 10, 26, 1, 22, tzinfo=dt_util.UTC),
        ),
        "input_boolean.b5": StoredState(State("input_boolean.b5", "off"), None, now),
    }

    for state in states:
        hass.states.async_set(state.entity_id, state.state, state.attributes)

    with patch(
        "homeassistant.helpers.restore_state.Store.async_save"
    ) as mock_write_data:
        await data.async_dump_states()

    assert mock_write_data.called
    args = mock_write_data.mock_calls[0][1]
    written_states = args[0]

    for state in states:
        hass.states.async_remove(state.entity_id)
    # b0 should not be written, since it didn't extend RestoreEntity
    # b1 should be written, since it is present in the current run
    # b2 should not be written, since it is not registered with the helper
    # b3 should be written, since it is still not expired
    # b4 should not be written, since it is now expired
    # b5 should be written, since current state is restored by entity registry
    assert len(written_states) == 3
    state0 = json_round_trip(written_states[0])
    state1 = json_round_trip(written_states[1])
    state2 = json_round_trip(written_states[2])
    assert state0["state"]["entity_id"] == "input_boolean.b1"
    assert state0["state"]["state"] == "on"
    assert state1["state"]["entity_id"] == "input_boolean.b3"
    assert state1["state"]["state"] == "off"
    assert state2["state"]["entity_id"] == "input_boolean.b5"
    assert state2["state"]["state"] == "off"

    # Test that removed entities are not persisted
    await entity.async_remove()

    for state in states:
        hass.states.async_set(state.entity_id, state.state, state.attributes)

    with patch(
        "homeassistant.helpers.restore_state.Store.async_save"
    ) as mock_write_data:
        await data.async_dump_states()

    assert mock_write_data.called
    args = mock_write_data.mock_calls[0][1]
    written_states = args[0]
    assert len(written_states) == 2
    state0 = json_round_trip(written_states[0])
    state1 = json_round_trip(written_states[1])
    assert state0["state"]["entity_id"] == "input_boolean.b3"
    assert state0["state"]["state"] == "off"
    assert state1["state"]["entity_id"] == "input_boolean.b5"
    assert state1["state"]["state"] == "off"


async def test_dump_error(hass: HomeAssistant) -> None:
    """Test that we cache data."""
    states = [
        State("input_boolean.b0", "on"),
        State("input_boolean.b1", "on"),
        State("input_boolean.b2", "on"),
    ]

    platform = MockEntityPlatform(hass, domain="input_boolean")
    entity = Entity()
    entity.hass = hass
    entity.entity_id = "input_boolean.b0"
    await platform.async_add_entities([entity])

    entity = RestoreEntity()
    entity.hass = hass
    entity.entity_id = "input_boolean.b1"
    await platform.async_add_entities([entity])

    data = async_get(hass)

    for state in states:
        hass.states.async_set(state.entity_id, state.state, state.attributes)

    with patch(
        "homeassistant.helpers.restore_state.Store.async_save",
        side_effect=HomeAssistantError,
    ) as mock_write_data:
        await data.async_dump_states()

    assert mock_write_data.called


async def test_load_error(hass: HomeAssistant) -> None:
    """Test that we cache data."""
    entity = RestoreEntity()
    entity.hass = hass
    entity.entity_id = "input_boolean.b1"

    with patch(
        "homeassistant.helpers.storage.Store.async_load",
        side_effect=HomeAssistantError,
    ):
        state = await entity.async_get_last_state()

    assert state is None


async def test_state_saved_on_remove(hass: HomeAssistant) -> None:
    """Test that we save entity state on removal."""
    platform = MockEntityPlatform(hass, domain="input_boolean")
    entity = RestoreEntity()
    entity.hass = hass
    entity.entity_id = "input_boolean.b0"
    await platform.async_add_entities([entity])

    now = dt_util.utcnow()
    hass.states.async_set(
        "input_boolean.b0", "on", {"complicated": {"value": {1, 2, now}}}
    )

    data = async_get(hass)

    # No last states should currently be saved
    assert not data.last_states

    await entity.async_remove()

    # We should store the input boolean state when it is removed
    state = data.last_states["input_boolean.b0"].state
    assert state.state == "on"
    assert isinstance(state.attributes["complicated"]["value"], list)
    assert set(state.attributes["complicated"]["value"]) == {1, 2, now.isoformat()}


async def test_restoring_invalid_entity_id(
    hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
    """Test restoring invalid entity IDs."""
    entity = RestoreEntity()
    entity.hass = hass
    entity.entity_id = "test.invalid__entity_id"
    now = dt_util.utcnow().isoformat()
    hass_storage[STORAGE_KEY] = {
        "version": 1,
        "key": STORAGE_KEY,
        "data": [
            {
                "state": {
                    "entity_id": "test.invalid__entity_id",
                    "state": "off",
                    "attributes": {},
                    "last_changed": now,
                    "last_updated": now,
                    "context": {
                        "id": "3c2243ff5f30447eb12e7348cfd5b8ff",
                        "user_id": None,
                    },
                },
                "last_seen": dt_util.utcnow().isoformat(),
            }
        ],
    }

    state = await entity.async_get_last_state()
    assert state is None


async def test_restore_entity_end_to_end(
    hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
    """Test restoring an entity end-to-end."""
    component_setup = Mock(return_value=True)

    setup_called = []

    entity_id = "test_domain.unnamed_device"
    data = async_get(hass)
    now = dt_util.utcnow()
    data.last_states = {
        entity_id: StoredState(State(entity_id, "stored"), None, now),
    }

    class MockRestoreEntity(RestoreEntity):
        """Mock restore entity."""

        def __init__(self) -> None:
            """Initialize the mock entity."""
            self._state: str | None = None

        @property
        def state(self) -> str | None:
            """Return the state."""
            return self._state

        async def async_added_to_hass(self) -> Coroutine[Any, Any, None]:
            """Run when entity about to be added to hass."""
            await super().async_added_to_hass()
            self._state = (await self.async_get_last_state()).state

    async def async_setup_platform(
        hass: HomeAssistant,
        config: ConfigType,
        async_add_entities: AddEntitiesCallback,
        discovery_info: DiscoveryInfoType | None = None,
    ) -> None:
        """Set up the test platform."""
        async_add_entities([MockRestoreEntity()])
        setup_called.append(True)

    mock_integration(hass, MockModule(DOMAIN, setup=component_setup))
    mock_integration(hass, MockModule(PLATFORM, dependencies=[DOMAIN]))

    platform = MockPlatform(async_setup_platform=async_setup_platform)
    mock_platform(hass, f"{PLATFORM}.{DOMAIN}", platform)

    component = EntityComponent(_LOGGER, DOMAIN, hass)

    await component.async_setup({DOMAIN: {"platform": PLATFORM, "sensors": None}})
    await hass.async_block_till_done()
    assert component_setup.called

    assert f"{PLATFORM}.{DOMAIN}" in hass.config.components
    assert len(setup_called) == 1

    platform = async_get_platform_without_config_entry(hass, PLATFORM, DOMAIN)
    assert platform.platform_name == PLATFORM
    assert platform.domain == DOMAIN
    assert hass.states.get(entity_id).state == "stored"

    await data.async_dump_states()
    await hass.async_block_till_done()

    storage_data = hass_storage[STORAGE_KEY]["data"]
    assert len(storage_data) == 1
    assert storage_data[0]["state"]["entity_id"] == entity_id
    assert storage_data[0]["state"]["state"] == "stored"

    await platform.async_reset()

    assert hass.states.get(entity_id) is None

    # Make sure the entity still gets saved to restore state
    # even though the platform has been reset since it should
    # not be expired yet.
    await data.async_dump_states()
    await hass.async_block_till_done()

    storage_data = hass_storage[STORAGE_KEY]["data"]
    assert len(storage_data) == 1
    assert storage_data[0]["state"]["entity_id"] == entity_id
    assert storage_data[0]["state"]["state"] == "stored"
